# Spectral GCN + Attention Recovery + LSTM
# This code trains and tests the GNN model for COVID-19 infection prediction in Tokyo
# Author: Jiawei Xue, August 26, 2021
# Step 1: read and pack the training and testing data
# Step 2: training epoch, training process, testing
# Step 3: build the model = spectral GCN + Attention Recovery + LSTM
# Step 4: main function
# Step 5: evaluation
# Step 6: visualization
import os
import csv
import json
import copy
import time
import random
import string
import argparse
import numpy as np
import pandas as pd
import geopandas as gpd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import matplotlib.pyplot as plt
import torch.nn.functional as F
from Memory_spectral_T3_GCN_memory_light_1_layer import SpecGCN
from Memory_spectral_T3_GCN_memory_light_1_layer import SpecGCN_LSTM
#torch.set_printoptions(precision=8)
#hyperparameters for the data setting
X_day, Y_day = 21,7
START_DATE, END_DATE = '20200414','20210207'
#START_DATE, END_DATE = '20200808','20210603'
#START_DATE, END_DATE = '20200720','20210515'
WINDOW_SIZE = 7
#hyperparameters for the learning
DROPOUT, ALPHA = 0.50, 0.20
NUM_EPOCHS, BATCH_SIZE, LEARNING_RATE = 100, 8, 0.0001
HIDDEN_DIM_1, OUT_DIM_1, HIDDEN_DIM_2 = 6,4,3
infection_normalize_ratio = 100.0
web_search_normalize_ratio = 100.0
train_ratio = 0.7
validate_ratio = 0.1
SEED = 5
torch.manual_seed(SEED)
random.seed(SEED)
#1.total period (mobility+text):
#from 20200201 to 20210620: (29+31+30+31+30+31+31+30+31+30+31)+(31+28+31+30+31+20)\
#= 335 + 171 = 506;
#2.number of zones: 23;
#3.infection period:
#20200331 to 20210620: (1+30+31+30+31+31+30+31+30+31)+(31+28+31+30+31+20) = 276 + 171 = 447.
#1. Mobility: functions 1.2 to 1.7
#2. Text: functions 1.8 to 1.14
#3. Infection: function 1.15
#4. Preprocess: functions 1.16 to 1.24
#5. Learn: functions 1.25 to 1.26
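#Illustrative check (toy code, not part of the pipeline) of the day counts above,
#using the standard datetime module.
from datetime import date
assert (date(2021, 6, 20) - date(2020, 2, 1)).days + 1 == 506  #mobility+text period
assert (date(2021, 6, 20) - date(2020, 3, 31)).days + 1 == 447 #infection period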
#function 1.1
#get the central area of Tokyo (i.e., the 23 special wards)
#return: a shapefile with the 23 zones
def read_tokyo_23():
folder = "/data/HSEES/xue/xue_codes/disease_prediction_ml/gml_code/present_model_version10/tokyo_23"
file = "tokyo_23zones.shp"
path = os.path.join(folder,file)
data = gpd.read_file(path)
return data
##################1.Mobility#####################
#function 1.2
#get the average of two days' mobility (infection) records
def mob_inf_average(data, key1, key2):
new_record = dict()
record1, record2 = data[key1], data[key2]
for i in record1:
if i in record2:
new_record[i] = (record1[i]+record2[i])/2.0
return new_record
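#Toy example (hypothetical records): zones missing on either day are dropped
#from the average.
_demo_avg = mob_inf_average({"d1": {"13101": 2.0}, "d2": {"13101": 4.0, "13102": 1.0}}, "d1", "d2")
assert _demo_avg == {"13101": 3.0} #"13102" is absent on "d1", so it is skipped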
#function 1.3
#get the average of multiple days' mobility (infection) records
def mob_inf_average_multiple(data, keyList):
new_record = dict()
num_day = len(keyList)
for i in range(num_day):
record = data[keyList[i]]
for zone_id in record:
if zone_id not in list(new_record.keys()):
new_record[zone_id] = record[zone_id]
else:
new_record[zone_id] += record[zone_id]
for new_record_key in new_record:
new_record[new_record_key] = new_record[new_record_key]*1.0/num_day
return new_record
#function 1.4
#generate the dateList: [20200101, 20200102, ..., 20211231]
def generate_dateList():
yearList = ["2020","2021"]
monthList = ["0"+str(i+1) for i in range(9)] + ["10","11","12"]
dayList = ["0"+str(i+1) for i in range(9)] + [str(i) for i in range(10,32)]
day_2020_num = [31,29,31,30,31,30,31,31,30,31,30,31]
day_2021_num = [31,28,31,30,31,30,31,31,30,31,30,31]
date_2020, date_2021 = list(), list()
for i in range(12):
for j in range(day_2020_num[i]):
date_2020.append(yearList[0] + monthList[i] + dayList[j])
for j in range(day_2021_num[i]):
date_2021.append(yearList[1] + monthList[i] + dayList[j])
date_2020_2021 = date_2020 + date_2021
return date_2020_2021
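#Quick sanity check (illustrative): 2020 is a leap year, so the full list
#covers 366 + 365 = 731 days.
_all_dates = generate_dateList()
assert _all_dates[0] == "20200101" and _all_dates[-1] == "20211231"
assert len(_all_dates) == 731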
#function 1.5
#smooth the mobility (infection) data using the neighborhood average
#under a given window size
#dateList: [20200101, 20200102, ..., 20211231]
def mob_inf_smooth(data, window_size, dateList):
data_copy = copy.copy(data)
data_key_list = list(data_copy.keys())
for data_key in data_key_list:
left = int(max(dateList.index(data_key)-(window_size-1)/2, 0))
right = int(min(dateList.index(data_key)+(window_size-1)/2, len(dateList)-1))
potential_neighbor = dateList[left:right+1]
neighbor_data_key = list(set(data_key_list).intersection(set(potential_neighbor)))
data_average = mob_inf_average_multiple(data_copy, neighbor_data_key)
data[data_key] = data_average
return data
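#Toy demonstration (hypothetical values) of the centered moving average:
#with window_size=3, each day becomes the mean of itself and its available
#neighbors, and the averages are computed from the original (unsmoothed) values.
_toy_smooth = {"2020010%d" % (i + 1): {"13101": float(i + 1)} for i in range(3)}
_toy_smooth = mob_inf_smooth(_toy_smooth, 3, generate_dateList())
assert _toy_smooth["20200102"]["13101"] == 2.0 #(1+2+3)/3
assert _toy_smooth["20200101"]["13101"] == 1.5 #only two days fall in the window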
#function 1.6
#set the mobility (infection) of one day as zero
def mob_inf_average_null(data, key1, key2):
new_record = dict()
record1, record2 = data[key1], data[key2]
for i in record1:
if i in record2:
new_record[i] = 0
return new_record
#function 1.7
#read the mobility data from "mobility_feature_20200201.json"...
#return: all_mobility:{"20200201":{('123','123'):12345,...},...}
#20200201 to 20210620: 506 days
def read_mobility_data(jcode23):
all_mobility = dict()
mobilityFilePath = "/data/HSEES/xue/xue_codes/disease_prediction_ml/gml_code/"+\
"present_model_version10/mobility_20210804"
mobilityNameList = os.listdir(mobilityFilePath)
for i in range(len(mobilityNameList)):
day_mobility = dict()
file_name = mobilityNameList[i]
if "20" in file_name:
day = (file_name.split("_")[2]).split(".")[0] #get the day
file_path = mobilityFilePath + '/' + file_name
f = open(file_path)
df_file = json.load(f) #read the mobility file
f.close()
for key in df_file:
origin, dest = key.split("_")[0], key.split("_")[1]
if origin in jcode23 and dest in jcode23:
if origin == dest:
day_mobility[(origin, dest)] = 0.0 #ignore the inner-zone flow
else:
day_mobility[(origin, dest)] = df_file[key]
all_mobility[day] = day_mobility
#missing data
all_mobility["20201128"] = mob_inf_average(all_mobility,"20201127","20201129")
all_mobility["20210104"] = mob_inf_average(all_mobility, "20210103","20210105")
return all_mobility
##################2.Text#####################
#function 1.8
#get the average of two days' text records
def text_average(data, key1, key2):
new_record = dict()
record1, record2 = data[key1], data[key2]
for i in record1:
if i in record2:
zone_record1, zone_record2 = record1[i], record2[i]
new_zone_record = dict()
for j in zone_record1:
if j in zone_record2:
new_zone_record[j] = (zone_record1[j] + zone_record2[j])/2.0
new_record[i] = new_zone_record
return new_record
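#Toy example (hypothetical records) of text_average: the same pairwise average,
#applied one level deeper on {zone: {symptom: count}} dictionaries.
_demo_text = text_average({"d1": {"13101": {"咳": 2.0}}, "d2": {"13101": {"咳": 4.0}}}, "d1", "d2")
assert _demo_text == {"13101": {"咳": 3.0}}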
#function 1.9
#get the average of multiple days' text records
def text_average_multiple(data, keyList):
new_record = dict()
num_day = len(keyList)
for i in range(num_day):
record = data[keyList[i]]
for zone_id in record: #zone_id
if zone_id not in new_record:
new_record[zone_id] = dict()
for j in record[zone_id]: #symptom
if j not in new_record[zone_id]:
new_record[zone_id][j] = record[zone_id][j]
else:
new_record[zone_id][j] += record[zone_id][j]
for zone_id in new_record:
for j in new_record[zone_id]:
new_record[zone_id][j] = new_record[zone_id][j]*1.0/num_day
return new_record
#function 1.10
#smooth the text data using the neighborhood average
#under a given window size
def text_smooth(data, window_size, dateList):
data_copy = copy.copy(data)
data_key_list = list(data_copy.keys())
for data_key in data_key_list:
left = int(max(dateList.index(data_key)-(window_size-1)/2, 0))
right = int(min(dateList.index(data_key)+(window_size-1)/2, len(dateList)-1))
potential_neighbor = dateList[left:right+1]
neighbor_data_key = list(set(data_key_list).intersection(set(potential_neighbor)))
data_average = text_average_multiple(data_copy, neighbor_data_key)
data[data_key] = data_average
return data
#function 1.11
#read the number of user points
def read_point_json():
with open('user_point/mobility_user_point.json') as point1:
user_point1 = json.load(point1)
with open('user_point/mobility_user_point_20210812.json') as point2:
user_point2 = json.load(point2)
user_point_all = dict()
for i in user_point1:
user_point_all[i] = user_point1[i]
for i in user_point2:
user_point_all[i] = user_point2[i]
user_point_all["20201128"] = user_point_all["20201127"] #data missing
user_point_all["20210104"] = user_point_all["20210103"] #data missing
return user_point_all
#function 1.12
#normalize the text search by the number of user points.
def normalize_text_user(all_text, user_point_all):
for day in all_text:
if day in user_point_all:
num_user = user_point_all[day]["num_user"]
all_text_day_new = dict()
all_text_day = all_text[day]
for zone in all_text_day:
if zone not in all_text_day_new:
all_text_day_new[zone] = dict()
for sym in all_text_day[zone]:
all_text_day_new[zone][sym] = all_text_day[zone][sym]*1.0/num_user
all_text[day] = all_text_day_new
return all_text
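#Toy example (hypothetical counts): 50 searches among 100 users give a
#per-user rate of 0.5.
_demo_norm = normalize_text_user({"20200201": {"13101": {"咳": 50.0}}},
                                 {"20200201": {"num_user": 100}})
assert _demo_norm["20200201"]["13101"]["咳"] == 0.5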
#function 1.13
#read the text data
#20200201 to 20210620: 506 days
#all_text = {"20200211":{"123":{"code":3,"fever":2,...},...},...}
def read_text_data(jcode23):
all_text = dict()
textFilePath = "/data/HSEES/xue/xue_codes/disease_prediction_ml/gml_code/"+\
"present_model_version10/text_20210804"
textNameList = os.listdir(textFilePath)
for i in range(len(textNameList)):
day_text = dict()
file_name = textNameList[i]
if "20" in file_name:
day = (file_name.split("_")[2]).split(".")[0]
file_path = textFilePath + "/" + file_name
f = open(file_path)
df_file = json.load(f) #read the text file
f.close()
new_dict = dict()
for key in df_file:
if key in jcode23:
new_dict[key] = {key1:df_file[key][key1]*1.0*web_search_normalize_ratio for key1 in df_file[key]}
#new_dict[key] = df_file[key]*WEB_SEARCH_RATIO
all_text[day] = new_dict
all_text["20201030"] = text_average(all_text, "20201029", "20201031") #data missing
return all_text
#function 1.14
#perform the min-max normalization for the text data.
def min_max_text_data(all_text,jcode23):
#calculate the min and max for each (zone, symptom) pair
#region_sym_min_max: {zone: {symptom: [min, max]}}
#text_list: 44 symptom keywords in Japanese (pain, headache, cough, diarrhea, stress, anxiety, ...)
text_list = list(['痛み', '頭痛', '咳', '下痢', 'ストレス', '不安', \
'腹痛', 'めまい', '吐き気', '嘔吐', '筋肉痛', '動悸', \
'副鼻腔炎', '発疹', 'くしゃみ', '倦怠感', '寒気', '脱水', \
'中咽頭', '関節痛', '不眠症', '睡眠障害', '鼻漏', '片頭痛', \
'多汗症', 'ほてり', '胸痛', '発汗', '無気力', '呼吸困難', \
'喘鳴', '目の痛み', '体の痛み', '無嗅覚症', '耳の痛み', \
'錯乱', '見当識障害', '胸の圧迫感', '鼻の乾燥', '耳感染症', \
'味覚消失', '上気道感染症', '眼感染症', '食欲減少'])
region_sym_min_max = dict()
for key in jcode23: #initialize
region_sym_min_max[key] = dict()
for sym in text_list:
region_sym_min_max[key][sym] = [1000000,0] #min, max
for day in all_text: #update
for key in jcode23:
for sym in text_list:
if sym in all_text[day][key]:
count = all_text[day][key][sym]
if count < region_sym_min_max[key][sym][0]:
region_sym_min_max[key][sym][0] = count
if count > region_sym_min_max[key][sym][1]:
region_sym_min_max[key][sym][1] = count
#print ("region_sym_min_max",region_sym_min_max)
for key in jcode23: #normalize
for sym in text_list:
min_count,max_count=region_sym_min_max[key][sym][0],region_sym_min_max[key][sym][1]
for day in all_text:
if sym in all_text[day][key]:
if max_count-min_count == 0:
all_text[day][key][sym] = 1
else:
all_text[day][key][sym] = (all_text[day][key][sym]-min_count)*1.0/(max_count-min_count)
#print("all_text[day][key][sym]",all_text[day][key][sym])
return all_text
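#Toy check (hypothetical counts) of the min-max normalization
#x' = (x - min) / (max - min), applied per (zone, symptom) pair.
_demo_mm = min_max_text_data({"d1": {"13101": {"痛み": 0.0}},
                              "d2": {"13101": {"痛み": 10.0}}}, ["13101"])
assert _demo_mm["d1"]["13101"]["痛み"] == 0.0
assert _demo_mm["d2"]["13101"]["痛み"] == 1.0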
##################3.Infection#####################
#function 1.15
#read the infection data
#20200331 to 20210620: (1+30+31+30+31+31+30+31+30+31)+(31+28+31+30+31+20) = 276 + 171 = 447.
#all_infection = {"20200201":{"123":1,"123":2}}
def read_infection_data(jcode23):
all_infection = dict()
infection_path = "/data/HSEES/xue/xue_codes/disease_prediction_ml/gml_code/"+\
"present_model_version10/patient_20210725.json"
f = open(infection_path)
df_file = json.load(f) #read the infection file
f.close()
for zone_id in df_file:
for one_day in df_file[zone_id]:
daySplit = one_day.split("/")
year, month, day = daySplit[0], daySplit[1], daySplit[2]
if len(month) == 1:
month = "0" + month
if len(day) == 1:
day = "0" + day
new_date = year + month + day
if str(zone_id[0:5]) in jcode23:
if new_date not in all_infection:
all_infection[new_date] = {zone_id[0:5]:df_file[zone_id][one_day]*1.0/infection_normalize_ratio}
else:
all_infection[new_date][zone_id[0:5]] = df_file[zone_id][one_day]*1.0/infection_normalize_ratio
#fill missing dates with neighboring-day averages
date_list = [str(20200316+i) for i in range(15)]
for date in date_list:
all_infection[date] = mob_inf_average(all_infection,'20200401','20200401')
all_infection['20200514'] = mob_inf_average(all_infection,'20200513','20200515')
all_infection['20200519'] = mob_inf_average(all_infection,'20200518','20200520')
all_infection['20200523'] = mob_inf_average(all_infection,'20200522','20200524')
all_infection['20200530'] = mob_inf_average(all_infection,'20200529','20200601')
all_infection['20200531'] = mob_inf_average(all_infection,'20200529','20200601')
all_infection['20201231'] = mob_inf_average(all_infection,'20201230','20210101')
all_infection['20210611'] = mob_inf_average(all_infection,'20210610','20210612')
#replace outlier dates with neighboring-day averages
all_infection['20200331'] = mob_inf_average(all_infection,'20200401','20200401')
all_infection['20200910'] = mob_inf_average(all_infection,'20200909','20200912')
all_infection['20200911'] = mob_inf_average(all_infection,'20200909','20200912')
all_infection['20200511'] = mob_inf_average(all_infection,'20200510','20200512')
all_infection['20201208'] = mob_inf_average(all_infection,'20201207','20201209')
all_infection['20210208'] = mob_inf_average(all_infection,'20210207','20210209')
all_infection['20210214'] = mob_inf_average(all_infection,'20210213','20210215')
#convert cumulative counts to daily new cases by differencing consecutive days
all_infection_subtraction = dict()
all_infection_subtraction['20200331'] = all_infection['20200331']
all_keys = list(all_infection.keys())
all_keys.sort()
for i in range(len(all_keys)-1):
record = dict()
for j in all_infection[all_keys[i+1]]:
record[j] = all_infection[all_keys[i+1]][j] - all_infection[all_keys[i]][j]
all_infection_subtraction[all_keys[i+1]] = record
return all_infection_subtraction, all_infection
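#The subtraction step above turns cumulative counts c_t into daily new cases
#n_t = c_t - c_{t-1}; a toy check of that arithmetic (hypothetical values):
_cum = [10.0, 12.0, 15.0]
_daily = [_cum[0]] + [_cum[i + 1] - _cum[i] for i in range(len(_cum) - 1)]
assert _daily == [10.0, 2.0, 3.0]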
##################4.Preprocess#####################
#function 1.16
#ensemble the mobility, text, and infection data
#all_mobility = {"20200201":{('13101','13102'):12345,...},...}
#all_text = {"20200201":{"13101":{"cold":3,"fever":2,...},...},...}
#all_infection = {"20200316":{"13101":1,"13102":2},...}
#all_x_y = {"0":[[mobility_1,text_1, ..., mobility_x_day,text_x_day],\
#[infection_1,...,infection_y_day],\
#[infection_1,...,infection_x_day], day_order]}
#x_days, y_days: use x_days of inputs to predict the next y_days
def ensemble(all_mobility, all_text, all_infection, x_days, y_days, all_day_list):
all_x_y = dict()
for j in range(len(all_day_list) - x_days - y_days + 1):
x_sample, y_sample, x_sample_infection = list(), list(), list()
#add the data from all_day_list[0+j] to all_day_list[x_days-1+j]
for k in range(x_days):
day = all_day_list[k + j]
x_sample.append(all_mobility[day])
x_sample.append(all_text[day])
x_sample_infection.append(all_infection[day]) #concatenate with the infection data
#add the data from all_day_list[x_days+j] to all_day_list[x_days+y_day-1+j]
for k in range(y_days):
day = all_day_list[x_days + k + j]
y_sample.append(all_infection[day])
all_x_y[str(j)] = [x_sample, y_sample, x_sample_infection,j]
return all_x_y
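#Toy example (hypothetical 4-day record) of the sliding window: with
#x_days=2 and y_days=1, a 4-day list yields 4-2-1+1 = 2 samples, each with
#2*(mobility+text) = 4 input entries and 1 target entry.
_demo_days = ["d0", "d1", "d2", "d3"]
_demo_const = {d: {"13101": 1.0} for d in _demo_days}
_demo_xy = ensemble(_demo_const, _demo_const, _demo_const, 2, 1, _demo_days)
assert len(_demo_xy) == 2
assert len(_demo_xy["0"][0]) == 4 and len(_demo_xy["0"][1]) == 1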
#function 1.17
#split the data by train/validate/test = train_ratio/validation_ratio/(1-train_ratio-validation_ratio)
def split_data(all_x_y, train_ratio, validation_ratio):
all_x_y_key = list(all_x_y.keys())
n = len(all_x_y_key)
n_train, n_validate = round(n*train_ratio), round(n*validation_ratio)
n_test = n-n_train-n_validate
train_key = [all_x_y[str(i)] for i in range(n_train)]
validate_key = [all_x_y[str(i+n_train)] for i in range(n_validate)]
test_key = [all_x_y[str(i+n_train+n_validate)] for i in range(n_test)]
return train_key, validate_key, test_key
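#Toy check (hypothetical samples) of the chronological split: with n=10 and
#ratios 0.7/0.1, the first 7 samples train, the next 1 validates, the last 2 test.
_demo_split = {str(i): i for i in range(10)}
_tr, _va, _te = split_data(_demo_split, 0.7, 0.1)
assert (len(_tr), len(_va), len(_te)) == (7, 1, 2)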
#function 1.18
#the second data split method
#split the data by train/validate/test = train_ratio/validation_ratio/(1-train_ratio-validation_ratio)
def split_data_2(all_x_y, train_ratio, validation_ratio):
all_x_y_key = list(all_x_y.keys())
n = len(all_x_y_key)
n_train, n_validate = round(n*train_ratio), round(n*validation_ratio)
n_test = n-n_train-n_validate
train_list, validate_list = list(), list()
train_validate_key = [all_x_y[str(i)] for i in range(n_train+n_validate)]
train_key, validate_key = list(), list()
for i in range(len(train_validate_key)):
if i % 9 == 8:
validate_key.append(all_x_y[str(i)])
validate_list.append(i)
else:
train_key.append(all_x_y[str(i)])
train_list.append(i)
test_key = [all_x_y[str(i+n_train+n_validate)] for i in range(n_test)]
return train_key, validate_key, test_key, train_list, validate_list
#function 1.19
#the third data split method
#split the data by train/validate/test = train_ratio/validation_ratio/(1-train_ratio-validation_ratio)
def split_data_3(all_x_y, train_ratio, validation_ratio):
all_x_y_key = list(all_x_y.keys())
n = len(all_x_y_key)
n_train, n_validate = round(n*train_ratio), round(n*validation_ratio)
n_test = n - n_train - n_validate
train_list, validate_list = list(), list()
train_validate_key = [all_x_y[str(i)] for i in range(n_train + n_validate)]
train_key, validate_key = list(), list()
for i in range(len(train_validate_key)):
if (n_train + n_validate-i) % 2 == 0 and (n_train + n_validate-i) <= 2*n_validate:
validate_key.append(all_x_y[str(i)])
validate_list.append(i)
else:
train_key.append(all_x_y[str(i)])
train_list.append(i)
test_key = [all_x_y[str(i+n_train+n_validate)] for i in range(n_test)]
return train_key, validate_key, test_key, train_list, validate_list
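#Toy check (hypothetical samples) of split_data_3: validation samples are taken
#at every other position from the tail of the train+validate block, so with
#n=10 and ratios 0.7/0.1 only index 6 goes to validation.
_tr3, _va3, _te3, _trl3, _val3 = split_data_3({str(i): i for i in range(10)}, 0.7, 0.1)
assert _val3 == [6] and len(_tr3) == 7 and len(_te3) == 2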
#function 1.20
#find the mobility dates starting x_days before start_date
#e.g., start_date = "20200331", x_days = 7
def sort_date(all_mobility, start_date, x_days):
mobility_date_list = list(all_mobility.keys())
mobility_date_list.sort()
idx = mobility_date_list.index(start_date)
mobility_date_cut = mobility_date_list[idx-x_days:]
return mobility_date_cut
#function 1.21
#find the mobility dates starting x_days before start_date
#and ending y_days after end_date
#e.g., start_date = "20200331", x_days = 7
def sort_date_2(all_mobility, start_date, x_days, end_date, y_days):
mobility_date_list = list(all_mobility.keys())
mobility_date_list.sort()
idx = mobility_date_list.index(start_date)
idx2 = mobility_date_list.index(end_date)
mobility_date_cut = mobility_date_list[idx-x_days:idx2+y_days]
return mobility_date_cut
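#Toy example (hypothetical dates): with x_days=1 the cut starts one day before
#start_date; the slice ends at index(end_date)+y_days, exclusive.
_dlist = ["20200101", "20200102", "20200103", "20200104", "20200105"]
_cut = sort_date_2({d: None for d in _dlist}, "20200103", 1, "20200104", 1)
assert _cut == ["20200102", "20200103", "20200104"]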
#function 1.22
#get the mappings from zone id to index and from symptom to index
#only these 8 symptom keywords are used as model input features
def get_zone_text_to_idx(all_infection):
zone_list = list(set(all_infection["20200401"].keys()))
text_list = list(['痛み', '頭痛', '咳', '下痢', 'ストレス', '不安', \
'腹痛', 'めまい'])
zone_list.sort()
zone_dict = {str(zone_list[i]):i for i in range(len(zone_list))}
text_dict = {str(text_list[i]):i for i in range(len(text_list))}
return zone_dict, text_dict
#function 1.23
#change the data format to matrix
#zoneid_to_idx = {"13101":0, "13102":1, ..., "13123":22}
#sym_to_idx = {"cough":0}
#mobility: {('13101', '13101'): 709973, ...}
#text: {'13101': {'痛み': 51,...},...} text
#infection: {'13101': 50, '13102': 137, '13103': 401,...}
#data_type: one of "mobility", "text", "infection"
def to_matrix(zoneid_to_idx, sym_to_idx, input_data, data_type):
n_zone, n_text = len(zoneid_to_idx), len(sym_to_idx)
if data_type == "mobility":
result = np.zeros((n_zone, n_zone))
for key in input_data:
from_id, to_id = key[0], key[1]
from_idx, to_idx = zoneid_to_idx[from_id], zoneid_to_idx[to_id]
result[from_idx][to_idx] += input_data[key]
if data_type == "text":
result = np.zeros((n_zone, n_text))
for key1 in input_data:
for key2 in input_data[key1]:
if key1 in zoneid_to_idx and key2 in sym_to_idx:
zone_idx, text_idx = zoneid_to_idx[key1], sym_to_idx[key2]
result[zone_idx][text_idx] += input_data[key1][key2]
if data_type == "infection":
result = np.zeros(n_zone)
for key in input_data:
zone_idx = zoneid_to_idx[key]
result[zone_idx] += input_data[key]
return result
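#Toy example (hypothetical flow) of the "mobility" branch: an OD pair maps to
#one cell of a zone-by-zone matrix.
_z2i = {"13101": 0, "13102": 1}
_demo_mat = to_matrix(_z2i, {"咳": 0}, {("13101", "13102"): 5.0}, "mobility")
assert _demo_mat.shape == (2, 2) and _demo_mat[0][1] == 5.0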
#function 1.24
#change the data to the matrix format
def change_to_matrix(data, zoneid_to_idx, sym_to_idx):
data_result = list()
for i in range(len(data)):
combine1, combine2 = list(), list()
combine3 = list() #NEW
mobility_text = data[i][0]
x_infection_all = data[i][2] #the x_days infection data
day_order = data[i][3] #NEW the order of the day
for j in range(len(mobility_text)//2):
mobility, text = mobility_text[2*j], mobility_text[2*j+1]
x_infection = x_infection_all[j] #NEW
new_mobility = to_matrix(zoneid_to_idx, sym_to_idx, mobility, "mobility")
new_text = to_matrix(zoneid_to_idx, sym_to_idx, text, "text")
combine1.append(new_mobility)
combine1.append(new_text)
new_x_infection = to_matrix(zoneid_to_idx, sym_to_idx, x_infection, "infection") #NEW
combine3.append(new_x_infection) #NEW
for j in range(len(data[i][1])):
infection = data[i][1][j]
new_infection = to_matrix(zoneid_to_idx, sym_to_idx, infection, "infection")
combine2.append(new_infection)
data_result.append([combine1,combine2,combine3,day_order]) #mobility/text; infection_y; infection_x; day_order
return data_result
##################5.learn#####################
#function 1.25
def visual_loss(e_losses, vali_loss, test_loss):
plt.figure(figsize=(4,3), dpi=300)
x = range(len(e_losses))
y1,y2,y3 = copy.copy(e_losses), copy.copy(vali_loss), copy.copy(test_loss)
plt.plot(x,y1,linewidth=1, label="train")
plt.plot(x,y2,linewidth=1, label="validate")
plt.plot(x,y3,linewidth=1, label="test")
plt.legend()
plt.title('Loss decline on entire training/validation/testing data')
plt.xlabel('Epoch')
plt.ylabel('Loss')
#plt.savefig('final_f6.png',bbox_inches = 'tight')
plt.show()
#function 1.26
def visual_loss_train(e_losses):
plt.figure(figsize=(4,3), dpi=300)
x = range(len(e_losses))
y1 = copy.copy(e_losses)
plt.plot(x,y1,linewidth=1, label="train")
plt.legend()
plt.title('Loss decline on entire training data')
plt.xlabel('Epoch')
plt.ylabel('Loss')
#plt.savefig('final_f6.png',bbox_inches = 'tight')
plt.show()
#function 2.1
#normalize each column of the input mobility matrix as one
def normalize_column_one(input_matrix):
column_sum = np.sum(input_matrix, axis=0)
row_num, column_num = len(input_matrix), len(input_matrix[0])
for i in range(row_num):
for j in range(column_num):
input_matrix[i][j] = input_matrix[i][j]*1.0/column_sum[j]
return input_matrix
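#Toy check (hypothetical matrix): after normalization every column sums to 1;
#note that a column summing to zero would cause a division by zero here.
_demo_col = np.array([[1.0, 0.0], [3.0, 2.0]])
_demo_col = normalize_column_one(_demo_col)
assert _demo_col[:, 0].sum() == 1.0 and _demo_col[:, 1].sum() == 1.0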
#function 2.2
#evaluate the trained_model on validation or testing data.
def validate_test_process(trained_model, vali_test_data):
criterion = nn.MSELoss()
vali_test_y = [vali_test_data[i][1] for i in range(len(vali_test_data))]
y_real = torch.tensor(vali_test_y)
vali_test_x = [vali_test_data[i] for i in range(len(vali_test_data))]
vali_test_x = convertAdj(vali_test_x)
y_hat = trained_model.run_specGCN_lstm(vali_test_x)
loss = criterion(y_hat.float(), y_real.float()) ###Calculate the loss
return loss, y_hat, y_real
#function 2.3
#convert the mobility matrices in x_batch in the following way:
#normalize the flows between zones so that the in-flow of each zone sums to 1.
def convertAdj(x_batch):
#x_batch:(n_batch, 0/1, 2*i+1)
x_batch_new = copy.copy(x_batch)
n_batch = len(x_batch)
days = round(len(x_batch[0][0])/2)
for i in range(n_batch):
for j in range(days):
mobility_matrix = x_batch[i][0][2*j]
x_batch_new[i][0][2*j] = normalize_column_one(mobility_matrix) #20210818
return x_batch_new
#function 2.4
#a training epoch
def train_epoch_option(model, opt, criterion, trainX_c, trainY_c, batch_size):
model.train()
losses = []
batch_num = 0
for beg_i in range(0, len(trainX_c), batch_size):
batch_num += 1
if batch_num % 16 ==0:
print ("batch_num: ", batch_num, "total batch number: ", int(len(trainX_c)/batch_size))
x_batch = trainX_c[beg_i:beg_i+batch_size]
y_batch = torch.tensor(trainY_c[beg_i:beg_i+batch_size])
opt.zero_grad()
x_batch = convertAdj(x_batch) #conduct the column normalization
y_hat = model.run_specGCN_lstm(x_batch) ###Attention
loss = criterion(y_hat.float(), y_batch.float()) #MSE loss
#opt.zero_grad()
loss.backward()
opt.step()
losses.append(loss.data.numpy())
return sum(losses)/float(len(losses)), model
#function 2.5
#multiple training epoch
def train_process(train_data, lr, num_epochs, net, criterion, bs, vali_data, test_data):
opt = optim.Adam(net.parameters(), lr, betas = (0.9,0.999), weight_decay=0)
train_y = [train_data[i][1] for i in range(len(train_data))]
e_losses = list()
e_losses_vali = list()
e_losses_test = list()
time00 = time.time()
for e in range(num_epochs):
time1 = time.time()
print ("current epoch: ",e, "total epoch: ", num_epochs)
number_list = list(range(len(train_data)))
random.shuffle(number_list) #the optional "random" argument was removed in Python 3.11
trainX_sample = [train_data[number_list[j]] for j in range(len(number_list))]
trainY_sample = [train_y[number_list[j]] for j in range(len(number_list))]
loss, net = train_epoch_option(net, opt, criterion, trainX_sample, trainY_sample, bs)
print ("train loss", loss*infection_normalize_ratio*infection_normalize_ratio)
e_losses.append(loss*infection_normalize_ratio*infection_normalize_ratio)
loss_vali, y_hat_vali, y_real_vali = validate_test_process(net, vali_data)
loss_test, y_hat_test, y_real_test = validate_test_process(net, test_data)
e_losses_vali.append(float(loss_vali)*infection_normalize_ratio*infection_normalize_ratio)
e_losses_test.append(float(loss_test)*infection_normalize_ratio*infection_normalize_ratio)
print ("validate loss", float(loss_vali)*infection_normalize_ratio*infection_normalize_ratio)
print ("test loss", float(loss_test)*infection_normalize_ratio*infection_normalize_ratio)
if e>=2 and (e+1)%10 ==0:
visual_loss(e_losses, e_losses_vali, e_losses_test)
visual_loss_train(e_losses)
time2 = time.time()
print ("running time for this epoch:", time2 - time1)
time01 = time.time()
print ("---------------------------------------------------------------")
print ("---------------------------------------------------------------")
#print ("total running time until now:", time01 - time00)
#print ("------------------------------------------------")
#print("specGCN_weight", net.specGCN.layer1.W)
#print("specGCN_weight_grad", net.specGCN.layer1.W.grad)
#print ("------------------------------------------------")
#print("memory decay matrix", net.v)
#print("memory decay matrix grad", net.v.grad)
#print ("------------------------------------------------")
#print ("lstm weight", net.lstm.all_weights[0][0])
#print ("lstm weight grad", net.lstm.all_weights[0][0].grad)
#print ("------------------------------------------------")
#print ("fc1.weight", net.fc1.weight)
#print ("fc1 weight grd", net.fc1.weight.grad)
#print ("---------------------------------------------------------------")
#print ("---------------------------------------------------------------")
return e_losses, net
#function 3.1
def read_data():
jcode23 = list(read_tokyo_23()["JCODE"]) #1.1 get the tokyo 23 zone shapefile
all_mobility = read_mobility_data(jcode23) #1.7 read the mobility data
all_text = read_text_data(jcode23) #1.13 read the text data
all_infection, all_infection_cum = read_infection_data(jcode23) #1.15 read the infection data (daily and cumulative)
#smooth the data using 7-days average
window_size = WINDOW_SIZE #20210818
dateList = generate_dateList() #20210818
all_mobility = mob_inf_smooth(all_mobility, window_size, dateList) #20210818
all_infection = mob_inf_smooth(all_infection, window_size, dateList) #20210818
#smooth, user, min-max.
point_json = read_point_json() #20210821
all_text = normalize_text_user(all_text, point_json) #20210821
all_text = text_smooth(all_text, window_size, dateList) #20210818
all_text = min_max_text_data(all_text,jcode23) #20210820
x_days, y_days = X_day, Y_day
mobility_date_cut = sort_date_2(all_mobility, START_DATE, x_days, END_DATE, y_days)
all_x_y = ensemble(all_mobility, all_text, all_infection, x_days, y_days, mobility_date_cut)
train_original, validate_original, test_original, train_list, validation_list =\
split_data_3(all_x_y,train_ratio,validate_ratio)
zone_dict, text_dict = get_zone_text_to_idx(all_infection) #get zone_idx, text_idx
train_x_y = change_to_matrix(train_original, zone_dict, text_dict) #get train
print ("train_x_y_shape",len(train_x_y),"train_x_y_shape[0]",len(train_x_y[0]))
validate_x_y = change_to_matrix(validate_original, zone_dict, text_dict) #get validate
test_x_y = change_to_matrix(test_original, zone_dict, text_dict) #get test
print (len(train_x_y)) #210
print (len(train_x_y[0][0])) #42 = 2*X_day
print (np.shape(train_x_y[0][0][0])) #(23, 23) mobility matrix
print (np.shape(train_x_y[0][0][1])) #(23, 8) text matrix
#print ("---------------------------------finish data reading and preprocessing------------------------------------")
return train_x_y, validate_x_y, test_x_y, all_mobility, all_infection, train_original, validate_original, test_original, train_list, validation_list
#function 3.2
#train the model
def model_train(train_x_y, vali_data, test_data):
#3.2.1 define the model
input_dim_1, hidden_dim_1, out_dim_1, hidden_dim_2 = len(train_x_y[0][0][1][1]),\
HIDDEN_DIM_1, OUT_DIM_1, HIDDEN_DIM_2
dropout_1, alpha_1, N = DROPOUT, ALPHA, len(train_x_y[0][0][1])
G_L_Model = SpecGCN_LSTM(X_day, Y_day, input_dim_1, hidden_dim_1, out_dim_1, hidden_dim_2, dropout_1,N) ###Attention
#3.2.2 train the model
num_epochs, batch_size, learning_rate = NUM_EPOCHS, BATCH_SIZE, LEARNING_RATE #model train
criterion = nn.MSELoss()
e_losses, trained_model = train_process(train_x_y, learning_rate, num_epochs, G_L_Model, criterion, batch_size,\
vali_data, test_data)
return e_losses, trained_model
#4.1
#read the data
train_x_y, validate_x_y, test_x_y, all_mobility, all_infection, \
train_original, validate_original, test_original, train_list, validation_list =\
read_data()
#train_x_y, validate_x_y, test_x_y = normalize(train_x_y, validate_x_y, test_x_y)
#train_x_y = train_x_y[0:30]
print (len(train_x_y))
print ("---------------------------------finish data preparation------------------------------------")
#Output:
#train_x_y_shape 210 train_x_y_shape[0] 4
#210
#42
#(23, 23)
#(23, 8)
#210
#---------------------------------finish data preparation------------------------------------
#4.2
#train the model
e_losses, trained_model = model_train(train_x_y, validate_x_y, test_x_y)
print ("---------------------------finish model training-------------------------")
#Training log (abridged; 26 batches per epoch, ~12-15 s per epoch;
#losses are rescaled by infection_normalize_ratio**2):
#current epoch: 0   train loss 71.65  validate loss 88.20  test loss 1353.51
#current epoch: 25  train loss 13.15  validate loss 18.28  test loss 361.92
#current epoch: 50  train loss 7.25   validate loss 9.37   test loss 189.29
#current epoch: 75  train loss 6.81   validate loss 7.81   test loss 169.33
#current epoch: 99  train loss 6.36   validate loss 7.34   test loss 157.65
running time for this epoch: 15.10403323173523 --------------------------------------------------------------- --------------------------------------------------------------- ---------------------------finish model training-------------------------
#4.3 check the sample counts of the train/validate/test splits
print (len(train_x_y))
print (len(validate_x_y))
print (len(test_x_y))
#4.3.1 model validation
validation_result, validate_hat, validate_real = validate_test_process(trained_model, validate_x_y)
print ("---------------------------------finish model validation------------------------------------")
print (len(validate_hat))
print (len(validate_real))
#4.3.2 model testing
test_result, test_hat, test_real = validate_test_process(trained_model, test_x_y)
print ("---------------------------------finish model testing------------------------------------")
print (len(test_real))
print (len(test_hat))
210
30
60
---------------------------------finish model validation------------------------------------
30
30
---------------------------------finish model testing------------------------------------
60
60
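The printed sizes match the 70/10/20 split implied by train_ratio = 0.7 and validate_ratio = 0.1; a quick sanity sketch (not part of the original pipeline):
#sanity check: 300 samples split as 210 train / 30 validate / 60 test
n_total = len(train_x_y) + len(validate_x_y) + len(test_x_y)   #300
n_train = int(n_total*train_ratio)                             #210
n_validate = int(n_total*validate_ratio)                       #30
n_test = n_total - n_train - n_validate                        #60
print (n_train, n_validate, n_test)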
#5.1 error metrics: RMSE, MAPE, MAE, RMSLE
#root mean squared error
def RMSELoss(yhat, y):
    return float(torch.sqrt(torch.mean((yhat - y)**2)))
#mean absolute percentage error
def MAPELoss(yhat, y):
    return float(torch.mean(torch.div(torch.abs(yhat - y), y)))
#mean absolute error
def MAELoss(yhat, y):
    return float(torch.mean(torch.abs(yhat - y)))
#root mean squared logarithmic error
def RMSLELoss(yhat, y):
    log_yhat = torch.log(yhat + 1)
    log_y = torch.log(y + 1)
    return float(torch.sqrt(torch.mean((log_yhat - log_y)**2)))
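As a quick check that the four metrics behave as expected, here is a toy example on made-up tensors (illustrative only, not from the experiment):
#toy check of the four metrics; expected values worked out by hand
yhat = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([1.0, 2.0, 5.0])
print (RMSELoss(yhat, y))    #sqrt(mean([0, 0, 4])) = 1.1547
print (MAPELoss(yhat, y))    #mean([0, 0, 2/5]) = 0.1333
print (MAELoss(yhat, y))     #mean([0, 0, 2]) = 0.6667
print (RMSLELoss(yhat, y))   #sqrt(mean((log[2,3,4]-log[2,3,6])**2)) = 0.2341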
#compute per-sample RMSE on the normalized scale
rmse_validate = list()
rmse_test = list()
for i in range(len(validate_x_y)):
    rmse_validate.append(RMSELoss(validate_hat[i], validate_real[i]))
for i in range(len(test_x_y)):
    rmse_test.append(RMSELoss(test_hat[i], test_real[i]))
print ("rmse_validate mean", np.mean(rmse_validate))
print ("rmse_test mean", np.mean(rmse_test))
#compute per-sample MAE on the normalized scale
mae_validate = list()
mae_test = list()
for i in range(len(validate_x_y)):
    mae_validate.append(MAELoss(validate_hat[i], validate_real[i]))
for i in range(len(test_x_y)):
    mae_test.append(MAELoss(test_hat[i], test_real[i]))
print ("mae_validate mean", np.mean(mae_validate))
print ("mae_test mean", np.mean(mae_test))
#rescale the errors from the normalized scale back to daily case counts and show RMSE and MAE together
mae_validate, rmse_validate, mae_test, rmse_test =\
np.array(mae_validate)*infection_normalize_ratio, np.array(rmse_validate)*infection_normalize_ratio,\
np.array(mae_test)*infection_normalize_ratio, np.array(rmse_test)*infection_normalize_ratio
print ("-----------------------------------------")
print ("mae_validate mean", round(np.mean(mae_validate),3), " rmse_validate mean", round(np.mean(rmse_validate),3))
print ("mae_test mean", round(np.mean(mae_test),3), " rmse_test mean", round(np.mean(rmse_test),3))
print ("-----------------------------------------")
rmse_validate mean 0.025301775607526427
rmse_test mean 0.10426167128517119
mae_validate mean 0.01962121918276547
mae_test mean 0.0802888042130936
-----------------------------------------
mae_validate mean 1.962  rmse_validate mean 2.53
mae_test mean 8.029  rmse_test mean 10.426
-----------------------------------------
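The per-sample loops above are repeated for every metric; a compact helper along the following lines would compute all four in one pass (a sketch; evaluate_metrics is not part of the original code):
#hypothetical helper: mean of each metric over (prediction, target) pairs,
#with the scale-dependent RMSE and MAE rescaled back to daily case counts
def evaluate_metrics(hat_list, real_list, ratio=infection_normalize_ratio):
    metrics = {"RMSE": RMSELoss, "MAPE": MAPELoss, "MAE": MAELoss, "RMSLE": RMSLELoss}
    result = dict()
    for name, fn in metrics.items():
        result[name] = float(np.mean([fn(hat_list[i], real_list[i]) for i in range(len(hat_list))]))
    result["RMSE"] = result["RMSE"]*ratio
    result["MAE"] = result["MAE"]*ratio
    return result
#e.g., print (evaluate_metrics(validate_hat, validate_real))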
#inspect the learned weight vector v (one value per ward)
print (trained_model.v)
Parameter containing:
tensor([0.0896, 0.0401, 0.0047, 0.0011, 0.0476, 0.0411, 0.0340, 0.0117, 0.0039,
0.0221, 0.0019, 0.0008, 0.0130, 0.0080, 0.0060, 0.0236, 0.0397, 0.0439,
0.0085, 0.0016, 0.0014, 0.0253, 0.0050], requires_grad=True)
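The 23 entries of v line up with the 23 wards; normalizing them to sum to one makes the relative zone weights easier to read (a sketch, assuming v is the per-ward attention parameter):
#normalize the learned weights so they sum to 1 (assumes one weight per ward)
v = trained_model.v.detach()
v_norm = v/torch.sum(v)
for k in range(len(v_norm)):
    print ("ward", k+1, "normalized weight", round(float(v_norm[k]), 4))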
#compare predicted and observed zone vectors on the last predicted day of the first validation sample
print (validate_hat[0][Y_day-1])
print (torch.sum(validate_hat[0][Y_day-1]))
print (validate_real[0][Y_day-1])
print (torch.sum(validate_real[0][Y_day-1]))
tensor([0.0000, 0.0298, 0.0725, 0.0629, 0.0190, 0.0343, 0.0443, 0.0627, 0.0915,
0.0769, 0.1305, 0.1486, 0.0458, 0.0635, 0.0618, 0.0358, 0.0408, 0.0324,
0.0499, 0.0753, 0.0933, 0.0644, 0.0864], grad_fn=<SelectBackward>)
tensor(1.4222, grad_fn=<SumBackward0>)
tensor([0.0100, 0.0357, 0.0771, 0.1000, 0.0243, 0.0371, 0.0300, 0.0329, 0.0671,
0.0443, 0.1800, 0.1643, 0.0514, 0.0700, 0.0743, 0.0286, 0.0443, 0.0343,
0.0543, 0.0757, 0.1014, 0.0457, 0.0486], dtype=torch.float64)
tensor(1.4314, dtype=torch.float64)
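Because infections were scaled by infection_normalize_ratio = 100, the two sums above correspond to roughly 142 predicted versus 143 observed citywide cases on that day:
#rescale the normalized citywide sums back to case counts
print (1.4222*infection_normalize_ratio, 1.4314*infection_normalize_ratio)   #142.22 143.14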
x = range(len(rmse_validate))
plt.figure(figsize=(8,2),dpi=300)
l1 = plt.plot(x, np.array(rmse_validate), 'ro-',linewidth=0.8, markersize=1.2, label='RMSE')
l2 = plt.plot(x, np.array(mae_validate), 'go-',linewidth=0.8, markersize=1.2, label='MAE')
plt.xlabel('Date from the first day of validation',fontsize=12)
plt.ylabel("RMSE/MAE daily new cases",fontsize=10)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.legend()
plt.grid()
plt.show()
x = range(len(mae_test))
plt.figure(figsize=(8,2),dpi=300)
l1 = plt.plot(x, np.array(rmse_test), 'ro-',linewidth=0.8, markersize=1.2, label='RMSE')
l2 = plt.plot(x, np.array(mae_test), 'go-',linewidth=0.5, markersize=1.2, label='MAE')
plt.xlabel('Date from the first day of test',fontsize=12)
plt.ylabel("RMSE/MAE Daily new cases",fontsize=10)
my_y_ticks = np.arange(0,2100, 500)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.legend()
plt.grid()
plt.show()
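The validation and test error plots above differ only in their inputs; a small helper (a sketch, not in the original code) would remove the duplication:
#hypothetical plotting helper: RMSE and MAE curves for one data split
def plot_errors(rmse_values, mae_values, split_name):
    x = range(len(rmse_values))
    plt.figure(figsize=(8,2), dpi=300)
    plt.plot(x, np.array(rmse_values), 'ro-', linewidth=0.8, markersize=1.2, label='RMSE')
    plt.plot(x, np.array(mae_values), 'go-', linewidth=0.8, markersize=1.2, label='MAE')
    plt.xlabel('Date from the first day of ' + split_name, fontsize=12)
    plt.ylabel("RMSE/MAE daily new cases", fontsize=10)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.legend()
    plt.grid()
    plt.show()
#plot_errors(rmse_validate, mae_validate, "validation")
#plot_errors(rmse_test, mae_test, "test")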
from scipy import stats
#validate
y_days = Y_day
validate_hat_sum = [float(torch.sum(validate_hat[i][y_days-1])) for i in range(len(validate_hat))]
validate_real_sum = [float(torch.sum(validate_real[i][y_days-1])) for i in range(len(validate_real))]
print ("the correlation between validation: ", stats.pearsonr(validate_hat_sum, validate_real_sum)[0])
#test
test_hat_sum = [float(torch.sum(test_hat[i][y_days-1])) for i in range(len(test_hat))]
test_real_sum = [float(torch.sum(test_real[i][y_days-1])) for i in range(len(test_real))]
print ("the correlation between test: ", stats.pearsonr(test_hat_sum, test_real_sum)[0])
#train: note that the first predicted day (index 0) is used here, whereas
#validation and test above use the last predicted day (y_days-1)
train_result, train_hat, train_real = validate_test_process(trained_model, train_x_y)
train_hat_sum = [float(torch.sum(train_hat[i][0])) for i in range(len(train_hat))]
train_real_sum = [float(torch.sum(train_real[i][0])) for i in range(len(train_real))]
print ("Pearson correlation (train): ", stats.pearsonr(train_hat_sum, train_real_sum)[0])
Pearson correlation (validation):  0.8950350848840413
Pearson correlation (test):  0.4827532375518371
Pearson correlation (train):  0.9776392523747837
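As a cross-check on stats.pearsonr, the same coefficient can be read off the numpy correlation matrix (a sketch):
#np.corrcoef returns the 2x2 correlation matrix; the off-diagonal entry is r
print ("test Pearson r (numpy):", np.corrcoef(test_hat_sum, test_real_sum)[0,1])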
#plot citywide daily totals: observed train/validate/test series with predictions
y1List = [np.sum(list(train_original[i+1][1][Y_day-1].values())) for i in range(len(train_original)-1)]
y2List = [np.sum(list(validate_original[i][1][Y_day-1].values())) for i in range(len(validate_original))]
y2List_hat = [float(torch.sum(validate_hat[i][Y_day-1])) for i in range(len(validate_hat))]
y3List = [np.sum(list(test_original[i][1][Y_day-1].values())) for i in range(len(test_original))]
y3List_hat = [float(torch.sum(test_hat[i][Y_day-1])) for i in range(len(test_hat))]
#x1 = np.array(range(len(y1List)))
#x2 = np.array([len(y1List)+j for j in range(len(y2List))])
x1 = train_list
x2 = validation_list
x3 = np.array([len(y1List)+len(y2List)+j for j in range(len(y3List))])
plt.figure(figsize=(8,2),dpi=300)
l1 = plt.plot(x1[0: len(y1List)], np.array(y1List)*infection_normalize_ratio, 'ro-',linewidth=0.8, markersize=2.0, label='train')
l2 = plt.plot(x2, np.array(y2List)*infection_normalize_ratio, 'go-',linewidth=0.8, markersize=2.0, label='validate')
l3 = plt.plot(x2, np.array(y2List_hat)*infection_normalize_ratio, 'g-',linewidth=2, markersize=0.1, label='validate_predict')
l4 = plt.plot(x3, np.array(y3List)*infection_normalize_ratio, 'bo-',linewidth=0.8, markersize=2, label='test')
l5 = plt.plot(x3, np.array(y3List_hat)*infection_normalize_ratio, 'b-',linewidth=2, markersize=0.1, label='test_predict')
#plt.xlabel('Date from the first day of 2020/4/1',fontsize=12)
plt.ylabel("Daily infection cases",fontsize=10)
my_y_ticks = np.arange(0,2100, 500)
my_x_ticks = [60*i for i in range(6)]   #x ticks at days 0, 60, ..., 300
plt.xticks(my_x_ticks, fontsize=8)
plt.yticks(my_y_ticks, fontsize=12)
plt.title("SpectralGCN")
plt.legend()
plt.grid()
plt.show()
def getPredictionPlot(k):
    #plot real vs. predicted daily cases for ward k over the test period
x_k = [i for i in range(len(test_real))]
real_k = [test_real[i][y_days-1][k] for i in range(len(test_real))]
predict_k = [test_hat[i][y_days-1][k] for i in range(len(test_hat))]
plt.figure(figsize=(4,2.5), dpi=300)
l1 = plt.plot(x_k, np.array(real_k)*infection_normalize_ratio, 'ro-',linewidth=0.8, markersize=2.0, label='real',alpha = 0.8)
l2 = plt.plot(x_k, np.array(predict_k)*infection_normalize_ratio, 'o-',color='black',linewidth=0.8, markersize=2.0, alpha = 0.8, label='predict')
#plt.xlabel('Date from the first day of 2020/4/1',fontsize=12)
#plt.ylabel("Daily infection cases",fontsize=10)
my_y_ticks = np.arange(0,100,35)
    my_x_ticks = [20*i for i in range(4)]   #x ticks at days 0, 20, 40, 60
    plt.xticks(my_x_ticks, fontsize=16)
    plt.yticks(my_y_ticks, fontsize=16)
plt.title("Urban district "+str(k+1))
#plt.legend(fontsize=20)
plt.legend(loc=2,fontsize=15)
#plt.grid()
#plt.savefig('peak3/' + str(k+1) + '.pdf',bbox_inches = 'tight')
plt.show()
#draw the prediction plot for each of the 23 wards
for i in range(23):
getPredictionPlot(i)
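To keep the 23 per-ward figures on disk, the commented-out savefig line inside getPredictionPlot can be revived; a minimal sketch, assuming a peak3/ output directory as in that comment:
#create the output directory, then uncomment the savefig call before plt.show():
#    plt.savefig('peak3/' + str(k+1) + '.pdf', bbox_inches='tight')
os.makedirs("peak3", exist_ok=True)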